#!/usr/bin/env python
# -*- coding:utf-8 -*-

from unicorn import *
from unicorn.x86_const import *
from capstone import *
import struct,collections
from emuapi.windows.peinfo import PeInfo
from emuapi.arch.archemu import ArchEmu
from emuapi.arch.register import Register
from emuapi.windows.teb import teb_struct
from emuapi.windows.wdll import is_dll_exists,dll_dict
from lib.utils import printer

class WinEmu(object):
    """Build a Unicorn emulation environment for Windows x86/x64 code.

    Two kinds of input are handled:

    * a PE image (``ispe=True``): its sections are mapped at the image base
      and every IAT slot is redirected to a fake one-byte ``ret`` stub;
    * a raw code blob (``ispe=False``): it is written at ``codeBase`` and,
      optionally, the DLL tables named in ``lDll`` are faked the same way.

    The mapped address ranges are recorded in ``codeZone``, ``dllZone`` and
    ``stackZone`` (flat lists of range boundaries) and the fake API stubs in
    ``win32Dict`` (DLL name -> dict with ``dllName``, ``dllBase``,
    ``dllLimit`` and ``apiDict`` mapping stub address -> API name).
    """

    # Lightweight record describing one PE section to map (built once,
    # not on every get_mapentry() call).
    _MapEntry = collections.namedtuple('MapEntry',['va','size'])

    def __init__(self,code,size = 1,lDll=None,ispe=False,isteb=False,isgdt=False,mode=32):
        """Store the configuration; nothing is mapped until build_emu_envir().

        :param code: handed to ``PeInfo`` when ``ispe`` is true, otherwise
            the raw bytes written at ``codeBase``.
        :param size: size of the raw-code region in MiB.
        :param lDll: DLL names whose API tables should be faked for raw code.
            Defaults to a fresh empty list per instance.
        :param ispe: treat ``code`` as a PE image.
        :param isteb: build a fake TEB (PE mode only).
        :param isgdt: build a fake GDT (PE mode only, currently a no-op).
        :param mode: 32 or 64; overwritten by the PE's architecture in PE mode.
        """
        self.codeZone = []
        self.dllZone = []
        self.stackZone = []
        self.win32Dict = {}
        self.codeBase = 0x00100000
        self.stackBase = 0x0
        self.stackSize = 0x0
        self.ispe = ispe
        self.isteb = isteb
        self.isgdt = isgdt
        self.mode = mode
        self.code = code
        # A literal [] default would be shared by every instance.
        self.lDll = lDll if lDll is not None else []
        if ispe:
            self.pe = PeInfo(self.code)
        self.size = size

    def build_emu_envir(self):
        """Create the emulator and map code, fake DLLs, TEB and stack.

        In PE mode a failed load prints an error and returns without
        building anything. Otherwise a summary of the mapped zones is
        printed at the end.
        """
        if self.ispe:
            if not self.pe.load_pe():
                print("PE文件加载失败!")
                # Nothing was mapped, so there is no environment to report.
                return
            self.mode = self.pe.get_arch()
            self.init_unicorn()
            self.base = self.pe.get_imageBase()
            self.pe_mem_map()
            self.pe_fake_win32_dll()
            # The stack must exist before the TEB is built, because the TEB
            # records the stack bounds.
            self.create_stack()
            if self.isteb:
                self.pe_fake_teb()
            if self.isgdt:
                self.pe_fake_gdt()
        else:
            self.init_unicorn()
            self.code_mem_map()
            self.data_mem_map()  # temporary
            self.code_fake_teb_pib_ldr()
            if self.lDll:
                self.code_fake_win32_dll()
            self.create_stack()
        self.show_build_info()

    @staticmethod
    def _zone_bounds(zone):
        """Return ``(lowest, highest)`` of a zone list, or ``(None, None)`` if empty."""
        if not zone:
            return None,None
        return min(zone),max(zone)

    def _zone_text(self,zone):
        """Format a zone as ``0x... ~ 0x...``; ``N/A`` when nothing is mapped."""
        base,limit = self._zone_bounds(zone)
        if base is None:
            return 'N/A'
        return '0x%x ~ 0x%x' % (base,limit)

    def show_build_info(self):
        """Normalise the zone lists and print a summary of the environment."""
        self.fix_zone()
        printer.plus("-------------------- 虚拟环境信息 --------------------",flag='')
        printer.plus('>>>CPU架构  --> %s [%s]' % ("X86",str(self.mode) + "位"),flag='')
        # Any zone may be empty (e.g. no fake DLLs for raw code), so the
        # bounds are formatted through a helper instead of indexing.
        printer.plus('>>>代码区   --> %s' % self._zone_text(self.codeZone),flag='')
        printer.plus('>>>库函数区 --> %s' % self._zone_text(self.dllZone),flag='')
        printer.plus('>>>栈区     --> %s' % self._zone_text(self.stackZone),flag='')
        printer.plus("-----------------------------------------------------",flag='')
        print("")

    def init_unicorn(self):
        """Create the Unicorn instance and the register helper for ``self.mode``."""
        archemu = ArchEmu(self.mode)
        self.mu = archemu.get_arch_x86()
        regArch = {32: 'X86', 64: 'X64'}.get(self.mode)
        # For any other mode no register helper is created.
        if regArch is not None:
            self.reg = Register(self.mu,regArch)

    def get_mapentry(self):
        """Return a list of ``MapEntry(va, size)`` for every PE section."""
        return [
            self._MapEntry(self.base + section.VirtualAddress,section.Misc_VirtualSize)
            for section in self.pe.get_sections()
        ]

    def get_bytes(self,va,length):
        """Read ``length`` bytes of PE data at virtual address ``va``."""
        return self.pe.get_data(va - self.base,length)

    def va(self,rva):
        """Convert a relative virtual address to a virtual address."""
        return rva + self.base

    def rva(self,va):
        """Convert a virtual address to a relative virtual address."""
        return va - self.base

    def align(self,value,alignment):
        """Round ``value`` up to the next multiple of ``alignment``."""
        remainder = value % alignment
        if remainder == 0:
            return value
        return value + (alignment - remainder)

    def pe_mem_map(self):
        """Map every PE section (rounded up to SectionAlignment) and copy its data."""
        alignSize = self.pe.get_SectionAlignment()
        for section in self.get_mapentry():
            mappedSize = self.align(section.size,alignSize)
            self.mu.mem_map(section.va,mappedSize)
            self.codeZone.append(section.va)
            self.codeZone.append(section.va + mappedSize)
            self.mu.mem_write(section.va,self.get_bytes(section.va,section.size))

    def code_mem_map(self):
        """Map ``size`` MiB at ``codeBase`` and write the raw code there."""
        memSize = 1024 * 1024 * self.size
        self.mu.mem_map(self.codeBase,memSize)
        self.codeZone.append(self.codeBase)
        self.codeZone.append(self.codeBase + memSize)
        self.mu.mem_write(self.codeBase,self.code)

    def data_mem_map(self):
        """Map one zeroed page at 0x5000 holding a fixed sample string.

        NOTE(review): looks like scratch data for a tutorial sample
        ("51testing") -- confirm whether it is still needed.
        """
        dataBase = 0x5000
        memSize = 0x1000
        self.mu.mem_map(dataBase,memSize)
        self.mu.mem_write(dataBase,b'\x00' * memSize)
        self.mu.mem_write(dataBase,b'51testingQuan')

    def create_stack(self):
        """Map a zeroed 1 MiB stack and point the stack/frame registers into it.

        Mode 64 uses a high 64-bit address and RSP/RBP; every other mode
        uses the 32-bit layout and ESP/EBP.
        """
        is64 = self.mode == 64
        if is64:
            self.stackBase = 0xffffffff00000000
            self.stackSize = 0x0000000000100000
        else:
            self.stackBase = 0x00300000
            self.stackSize = 0x00100000
        self.mu.mem_map(self.stackBase,self.stackSize)
        self.stackZone.append(self.stackBase)
        self.stackZone.append(self.stackBase + self.stackSize)
        self.mu.mem_write(self.stackBase,b'\x00' * self.stackSize)
        if is64:
            self.mu.reg_write(UC_X86_REG_RSP,self.stackBase + 0x8000)
            self.mu.reg_write(UC_X86_REG_RBP,self.stackBase + 0x10000)
        else:
            self.mu.reg_write(UC_X86_REG_ESP,self.stackBase + 0x800)
            self.mu.reg_write(UC_X86_REG_EBP,self.stackBase + 0x1000)

    def pe_fake_win32_dll(self):
        """Redirect every IAT entry of the PE to a fake ``ret`` stub.

        One page filled with 0xC3 (``ret``) is mapped per imported DLL;
        each import gets its own byte in that page and the IAT slot is
        overwritten with that byte's address.
        """
        sysBase = 0xff00000
        dllSize = 0x1000
        # IAT slots are pointer sized: 8 bytes for 64-bit images, else 4.
        ptrFmt = '<Q' if self.mode == 64 else '<I'
        for entry in self.pe.get_iat():
            dllBase = sysBase + len(self.win32Dict) * dllSize
            dllDict = {
                'apiDict': {},
                'dllName': entry.dll.decode(),
                'dllBase': dllBase,
                'dllLimit': dllBase + dllSize - 1,
            }
            self.mu.mem_map(dllBase,dllSize)
            self.dllZone.append(dllBase)
            self.dllZone.append(dllBase + dllSize)
            self.mu.mem_write(dllBase,b'\xC3' * dllSize)  # ret
            for imp in entry.imports:
                if imp.name is not None:
                    apiName = imp.name.decode()
                else:
                    # presumably an import by ordinal (pefile-style entry
                    # with no name) -- confirm against PeInfo.get_iat()
                    apiName = 'ordinal_%s' % getattr(imp,'ordinal','?')
                apiAddr = dllBase + len(dllDict['apiDict'])
                self.mu.mem_write(imp.address,struct.pack(ptrFmt,apiAddr))
                dllDict['apiDict'][apiAddr] = apiName
            self.win32Dict[dllDict['dllName']] = dllDict

    def code_fake_win32_dll(self):
        """Fake the API tables of the DLLs in ``lDll`` for raw code.

        Does nothing unless every requested DLL is known to
        ``is_dll_exists``. A fixed region at 0x7dc65000 is mapped to hold
        the tables; each table slot returned by ``dll_dict`` is then
        pointed at a one-byte ``ret`` stub in a per-DLL stub region.
        """
        sysBase = 0xff00000
        dllSize = 0x1000
        if not all(is_dll_exists(dll) for dll in self.lDll):
            return

        # Look every table up once and reuse it below.
        tables = [dll_dict(dll) for dll in self.lDll]

        # The stub-region size (shared by all DLLs) is the page-aligned
        # address span of the last DLL's table, with a one-page floor so
        # mem_map is never asked for zero bytes.
        # NOTE(review): sizing by import count would be more direct --
        # confirm the intended layout before changing it.
        if tables and tables[-1][1]:
            lastAddrs = tables[-1][1]
            dllSize = max(self.align(lastAddrs[-1] - lastAddrs[0],0x1000),0x1000)

        # Fixed region in which the table slots written below live.
        self.mu.mem_map(0x7dc65000,self.align(0x139b000,0x1000))

        for dll,(imports,_) in zip(self.lDll,tables):
            dllBase = sysBase + len(self.win32Dict) * dllSize
            dllDict = {
                'apiDict': {},
                'dllName': dll,
                'dllBase': dllBase,
                'dllLimit': dllBase + dllSize - 1,
            }
            self.mu.mem_map(dllBase,dllSize)
            self.dllZone.append(dllBase)
            self.dllZone.append(dllBase + dllSize)
            self.mu.mem_write(dllBase,b'\xC3' * dllSize)  # ret
            for address,apiName in imports.items():
                apiAddr = dllBase + len(dllDict['apiDict'])
                self.mu.mem_write(address,struct.pack('<I',apiAddr))
                dllDict['apiDict'][apiAddr] = apiName
            self.win32Dict[dllDict['dllName']] = dllDict

    def pe_fake_teb(self):
        """Map a fake TEB at address 0 and point FS at it.

        The stack bounds are taken from ``stackBase``/``stackSize``, so
        ``create_stack`` should have run first.
        """
        tebBase = 0
        pebBase = tebBase + 0x1000
        # assumes teb_struct follows the NT_TIB layout (ExceptionList,
        # StackBase = highest address, StackLimit = lowest) -- TODO confirm
        stackTop = self.stackBase + self.stackSize
        stackLimit = self.stackBase
        teb = teb_struct(
            -1,
            stackTop,
            stackLimit,
            0,
            0,
            0,
            tebBase,
            0,
            0xeeeeeeee,
            0xeeeeeeee,
            0,
            0,
            pebBase,
            0,
            0,
            0,
            0
        )
        tebBytes = bytes(teb)
        self.mu.mem_map(tebBase,1024 * 1024 * 2)
        self.mu.mem_write(tebBase,tebBytes)
        # FS takes the TEB address, not the structure contents.
        self.mu.reg_write(UC_X86_REG_FS,tebBase)

    def code_fake_teb_pib_ldr(self):
        """Map a minimal TEB/PEB/LDR area at address 0 for raw code.

        The dword at TEB+0x30 holds the PEB address and the dword at
        PEB+0x0C holds the LDR address; both structures live inside the
        same two mapped pages.
        """
        TEB = 0
        PEB = TEB + 0x30
        LDR = PEB + 0x0C
        self.mu.mem_map(TEB,0x2000)
        self.mu.mem_write(PEB,struct.pack('<i',PEB))
        self.mu.mem_write(LDR,struct.pack('<i',LDR))
        self.mu.reg_write(UC_X86_REG_FS,TEB)

    def pe_fake_gdt(self):
        """Placeholder: GDT setup is not implemented."""
        pass

    def get_win32_api_by_addr(self,addr):
        """Return the API name whose stub is at ``addr``, or None.

        None is returned both when ``addr`` is outside every fake DLL and
        when it is inside one but not on a registered stub.
        """
        for dll in self.win32Dict.values():
            if dll['dllBase'] <= addr <= dll['dllLimit']:
                return dll['apiDict'].get(addr)
        return None

    def get_win32_api_dict(self):
        """Return the DLL-name -> fake-DLL-description dictionary."""
        return self.win32Dict

    def get_valid_zone(self):
        """Return ``(codeZone, dllZone, stackZone)`` as base/limit dicts.

        Each dict has the keys ``base`` and ``limit``; both are None for
        a zone in which nothing has been mapped.
        """
        zones = []
        for zone in (self.codeZone,self.dllZone,self.stackZone):
            base,limit = self._zone_bounds(zone)
            zones.append({'base': base,'limit': limit})
        return tuple(zones)

    def get_pe_run_range(self):
        """Return ``(entry address, image base)`` as reported by the PE helper."""
        addrEntry = self.pe.get_addrEntry()
        imageBase = self.pe.get_imageBase()
        return addrEntry,imageBase

    def fix_zone(self):
        """Deduplicate and sort the boundaries of all three zone lists."""
        self.codeZone = sorted(set(self.codeZone))
        self.dllZone = sorted(set(self.dllZone))
        self.stackZone = sorted(set(self.stackZone))

    def valid_zone(self,address):
        """Return True if ``address`` lies inside the code, DLL or stack zone.

        Zones in which nothing has been mapped are skipped.
        """
        for zone in (self.codeZone,self.dllZone,self.stackZone):
            base,limit = self._zone_bounds(zone)
            if base is not None and base <= address <= limit:
                return True
        return False
